from torch.nn.functional import fold, unfold, pad
import matplotlib.pyplot as plt
import warmup_scheduler
from pytorch_lightning.utilities import grad_norm
from torch.nn.functional import grid_sample
import torch
from collections import OrderedDict
import wandb
import random
from diffAug import DiffAug
from torchvision import datasets, transforms
import math
from svm_modules import Support_Vectors
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from utils import get_params_list, upscale_tensor, only_right_loss
import torchmetrics
from model_pool import SimpleCNN, ViT



def create_model(architecture, config):
    """Build the backbone network selected by ``architecture``.

    Args:
        architecture: One of ``"resnet"``, ``"resnet50"``, ``"4conv"`` or
            ``"vit"``.
        config: Mapping of run settings. ``config['NUM_CLASSES']`` is read for
            every architecture except ``"resnet"``; ``config['ORIGINAL_SHAPE']``
            is additionally read for ``"4conv"``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
            (Previously the function fell through and returned ``None``,
            which only failed later when the model was first used.)
    """
    if architecture == "resnet":
        # torchvision ResNet-50 requested with pretrained=True; the stock
        # classifier head is left untouched.
        return torchvision.models.resnet50(pretrained=True)
    if architecture == "resnet50":
        # ResNet-50 without pretrained weights, with the final linear layer
        # resized to the configured number of classes.
        model = torchvision.models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, config['NUM_CLASSES'])
        return model
    if architecture == "4conv":
        return SimpleCNN(config['ORIGINAL_SHAPE'], config['NUM_CLASSES'])
    if architecture == "vit":
        return ViT(num_classes = config['NUM_CLASSES'])
    raise ValueError(
        f"Unknown architecture {architecture!r}; "
        "expected one of 'resnet', 'resnet50', '4conv', 'vit'"
    )

class LitResnet(LightningModule):
    """Classifier LightningModule around a backbone built by ``create_model``.

    ``config`` keys read here: ``DATASET``, ``ARCHITECTURE``, ``NUM_CLASSES``,
    ``CLASS_NAME``, the optional flags ``TRANSFER`` and ``pretrained``, and
    ``REF_MODEL_PATH`` when either flag is set.

    When ``DATASET`` is ``'CelebA'`` the steps use binary cross-entropy with
    logits and an F1 metric; every other dataset uses NLL over log-softmax and
    an accuracy metric.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config)
        self.dataset = config['DATASET']
        self.num_classes = config['NUM_CLASSES']

        self.model = create_model(config['ARCHITECTURE'], config)

        if config.get('TRANSFER', False):
            print("I am transfer learning")
            # Load every checkpoint entry except those whose key contains
            # 'fc'; strict=False tolerates the entries that are left out.
            self.model.load_state_dict(
                self._read_checkpoint_weights(config['REF_MODEL_PATH'], skip_fc=True),
                strict = False,
            )
            # NOTE(review): this loop only (re)enables gradients on the fc
            # parameters; no other parameter is frozen here, and
            # configure_optimizers() optimises all of self.model.parameters().
            # Litreconstruction.get_stationarity differentiates with respect
            # to every model parameter, so switching requires_grad off here
            # would need a matching change there.
            for name, param in self.model.named_parameters():
                if 'fc' in name:
                    param.requires_grad = True

        if config.get('pretrained', False):
            print("I am pretrained")
            # Strict load: the checkpoint must match the model exactly.
            self.model.load_state_dict(
                self._read_checkpoint_weights(config['REF_MODEL_PATH'], skip_fc=False)
            )

        # NOTE(review): train_params always ends up as the full parameter
        # iterator, regardless of TRANSFER, and is not read anywhere in this
        # class. Confirm whether an fc-only selection was intended for
        # transfer learning.
        self.train_params = self.model.parameters()
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = self.num_classes)
        self.f1_score = torchmetrics.F1Score(task='multilabel', num_labels= self.num_classes)
        self.class_name = config['CLASS_NAME']
        # Per-class counters sized from the configuration instead of a fixed 10.
        self.class_correct = torch.zeros(self.num_classes)
        self.class_totals = torch.zeros(self.num_classes)
        self.test_output = []
        self.test_target = []

    @staticmethod
    def _read_checkpoint_weights(path, skip_fc):
        """Read ``state_dict`` from the checkpoint at ``path`` and re-key it.

        Args:
            path: Checkpoint file passed to ``torch.load``.
            skip_fc: When true, entries whose key contains ``'fc'`` are left out.

        Returns:
            An ``OrderedDict`` of the remaining entries with the first six
            characters removed from every key.
        """
        # map_location="cpu" lets a checkpoint written on a GPU host be opened
        # on a machine without one; load_state_dict() copies the values into
        # the model's own tensors afterwards.
        checkpoint = torch.load(path, map_location="cpu")
        weights = OrderedDict()
        for key, value in checkpoint['state_dict'].items():
            if skip_fc and 'fc' in key:
                continue
            # Presumably strips a "model." attribute prefix — TODO confirm
            # against the checkpoint layout.
            weights[key[6:]] = value
        return weights

    def feature_extractor(self, x):
        """Delegate to ``self.model.feature_extractor``."""
        return self.model.feature_extractor(x)

    def forward(self, x, get_bn = False):
        """Run the wrapped model.

        With ``get_bn=True`` the model is called with ``get_bn=True`` and its
        two return values are passed through as a pair; otherwise only the
        model output is returned.
        """
        if get_bn:
            out, bn_diff = self.model(x, get_bn = True)
            return out, bn_diff
        return self.model(x)

    def _loss_and_metric(self, batch):
        """Compute the loss and the dataset-specific metric for one batch.

        Returns:
            ``(loss, metric_name, metric_value)`` where ``metric_name`` is
            ``"f1"`` for CelebA and ``"acc"`` otherwise.
        """
        x, y = batch
        logits = self(x)
        if self.dataset == 'CelebA':
            loss = F.binary_cross_entropy_with_logits(logits, y.float())
            return loss, "f1", self.f1_score(torch.sigmoid(logits), y.int())
        log_probs = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        return loss, "acc", self.accuracy(preds, y)

    def training_step(self, batch, batch_idx):
        """Log ``train_<metric>`` and ``train_loss`` and return the loss."""
        loss, metric_name, metric = self._loss_and_metric(batch)
        self.log(f"train_{metric_name}", metric, prog_bar=True)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and metric; log them under ``stage`` when it is given."""
        loss, metric_name, metric = self._loss_and_metric(batch)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_{metric_name}", metric, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def calculate_class_accuracy(self, test_labels, test_predicted):
        """Return per-class accuracy as a list with one entry per class.

        Args:
            test_labels: Sequence of integer class labels.
            test_predicted: Sequence of predicted labels, aligned with
                ``test_labels``.

        Returns:
            A list of length ``NUM_CLASSES``; classes with no samples get 0.
        """
        class_correct = [0.0] * self.num_classes
        class_total = [0.0] * self.num_classes

        with torch.no_grad():
            for label, pred in zip(test_labels, test_predicted):
                label = int(label)
                if int(pred) == label:
                    class_correct[label] += 1
                class_total[label] += 1

        return [
            correct / total if total > 0 else 0
            for correct, total in zip(class_correct, class_total)
        ]

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Adam over all model parameters with a StepLR schedule (step 40, gamma 0.1)."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2,weight_decay=1e-7)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size= 40, gamma=0.1)
        return [self.optimizer], [self.scheduler]



class Litreconstruction(LightningModule):
    """LightningModule that optimises sample tensors and per-sample lambdas.

    The optimiser built in ``configure_optimizers`` updates only
    ``trainable_params_1`` and ``trainable_lambda``. ``frozen_net`` (a
    ``LitResnet``) supplies outputs and parameter gradients for the loss
    terms assembled in ``compute_loss``.
    """

    def __init__(self, args, config): 
        super().__init__()
        self.is_transfer = config.get('TRANSFER', False)
        # NOTE: the caller's config dict is updated in place (NUM_SAMPLES
        # here, ORIGINAL_SHAPE for MNIST below).
        config['NUM_SAMPLES'] = args.num_samples
        self.save_hyperparameters(args, config)
        self.Aug = DiffAug(config, args.policy, args.crop_size, args.angle, args.noise_ratio, args.translation_ratio, args.cutout_ratio)

        if config['DATASET'] == 'MNIST':
            config['ORIGINAL_SHAPE'] = (1, 32, 32)

        self.ORIGINAL_SHAPE = config['ORIGINAL_SHAPE']
        self.NUM_SAMPLES = args.num_samples
        # NOTE(review): fixed at 1000 rather than config['NUM_CLASSES'];
        # it sizes index_pool and the CelebA one-hot target — confirm.
        self.num_classes = 1000
        self.CLASS_NAME = config['CLASS_NAME']
        print(self.num_classes)
        self.loss_size = 3
        self.DATASET = config['DATASET']

        self.lr_image = args.lr_image
        self.lr_lambda = args.lr_lambda
        self.momentum_image = args.momentum_image
        self.momentum_lambda = args.momentum_lambda

        self.stationarity_rate = args.stationarity_rate
        self.temperature = args.temperature
        self.primal_rate = args.primal_rate
        self.aug_stationarity_rate = args.aug_stationarity_rate
        self.aug_primal_rate = args.aug_primal_rate

        self.weight_decay_x = args.weight_decay_x
        self.weight_decay_l = args.weight_decay_l

        # Reference network. Its weights are whatever LitResnet.__init__
        # loads for this config (TRANSFER / pretrained flags); nothing more
        # is loaded here.
        self.frozen_net = LitResnet(config)
        self.frozen_net = self.frozen_net.cuda()
        # The reference parameters keep requires_grad=True on purpose:
        # get_stationarity() calls torch.autograd.grad with respect to them.
        # They are not handed to the optimiser in configure_optimizers(), so
        # no optimiser step modifies them.
        self.grad_scale = 1

        self.UPSCALE_cycle = args.UPSCALE_cycle
        self.temperature1 = args.temperature1
        self.ZIP_RATIO = 1
        self.par_mult = args.par_mult
        self.ZIP_RATE = 1
        self.no_diff = self.ZIP_RATE == 1

        self.loss_weight = args.loss_weight
        self.stationarity_weight = args.stationarity_weight

        self.tv_scale = args.tv_scale
        self.alpha_scale = args.alpha_scale
        self.is_celeba = config['DATASET'] == 'CelebA'

        self.trainable_params_4, self.trainable_lambda, self.target, self.output_tensor = get_params_list(config, 4, x_sigma = args.x_normal_std, l_sigma=args.l_normal_std)
        self.trainable_params_2, _, _, _ = get_params_list(config, n = 2, x_sigma = args.x_normal_std, l_sigma=args.l_normal_std)
        self.trainable_params_1, _, _, _ = get_params_list(config, n = 1, x_sigma = args.x_normal_std, l_sigma=args.l_normal_std)
        self.index_pool = torch.arange(self.NUM_SAMPLES * self.num_classes).to(self.frozen_net.device)
        self.x_normal_std = args.x_normal_std
        self.l_normal_std = args.l_normal_std

        self.architecture = config['ARCHITECTURE']

        self.batch_size = min(len(self.trainable_lambda), args.batch_size)

        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = config['NUM_CLASSES'])

        # Running averages updated in get_lambda_nll / get_nll_both.
        self.primal_acc = 0
        self.aug_primal_acc = 0
        self.my_acc = 0

        self.cut_lambda = False
        self.acc_threshold = 0.9

    def on_after_backward(self):
        """Clamp the sample tensors to [0, 1] and log norm diagnostics to wandb."""
        with torch.no_grad():
            if self.ZIP_RATE == 1:
                self.trainable_params_1.data.clamp_(0, 1)
                self.trainable_params_2.data.clamp_(0, 1)
            elif self.ZIP_RATE == 2:
                self.trainable_params_2.data.clamp_(0, 1)

            grad = self.trainable_params_1.grad
            if grad is None:
                # No gradient reached trainable_params_1 in this backward
                # pass, so there is nothing to report.
                return

            # Diagnostics use the first sample only.
            sample_norm = torch.norm(self.trainable_params_1[0])
            sample_grad_norm = torch.norm(grad[0])
            ratio = sample_grad_norm / sample_norm

            wandb.log({"sample_norm": sample_norm, "sample_grad_norm": sample_grad_norm, "ratio": ratio})

    def frozen_model_forward(self, myinput, is_augment=False, one_temp = False):
        """Return log-softmax outputs of the reference network for ``myinput``.

        Args:
            myinput: Tensor reshaped to ``(-1, *ORIGINAL_SHAPE)`` before the
                forward pass.
            is_augment: Apply ``self.Aug.DiffAugment`` to the input first.
            one_temp: Accepted for compatibility; it has no effect because
                the logits are always divided by ``self.temperature``.
                NOTE(review): ``self.temperature1`` is stored but never used
                here — confirm whether one branch was meant to use it.
        """
        if is_augment:
            myinput = self.Aug.DiffAugment(myinput)

        x = myinput.reshape(-1, *self.ORIGINAL_SHAPE).cuda()

        if self.DATASET == 'MNIST':
            # Repeat the single channel three times for the backbone.
            x = x.expand(-1, 3, -1, -1)

        output = self.frozen_net(x)
        return F.log_softmax(output / self.temperature, dim=-1)

    def get_stationarity(self, samples, samples_lambda, target, is_augment=False):
        """Compute the stationarity term from gradients w.r.t. the reference parameters.

        The lambda-weighted loss from ``get_lambda_nll`` is differentiated
        with respect to every parameter of ``frozen_net.model``. For each
        parameter with more than one dimension (and, when ``is_transfer`` is
        set, only those whose name contains ``'fc'``) the transformed,
        normalised gradient is added to ``grad_scale`` times the normalised
        parameter, and the squared entries of that sum are accumulated.

        Returns:
            The accumulated value (divided by 10 for the ``'resnet'``
            architecture) multiplied by 100000.

        Raises:
            RuntimeError: If the accumulated value is NaN.
        """
        loss = self.get_lambda_nll(samples, target, samples_lambda, is_augment = is_augment)
        stationarity_loss = 0
        param_list = list(self.frozen_net.model.parameters())

        # create_graph=True keeps the gradients differentiable so the result
        # can be back-propagated to the samples and lambdas.
        grads = torch.autograd.grad(
            outputs = loss,
            inputs = param_list,
            grad_outputs=torch.ones_like(
                loss, requires_grad=False, device=self.frozen_net.device
            ).div(self.par_mult),
            create_graph=True,
            retain_graph=True,
        )

        name_list = {p: name for name, p in self.frozen_net.model.named_parameters()}
        dim = 2

        for p, grad in zip(param_list, grads):
            if len(p.shape) == 1:
                # Skip 1-D parameters.
                continue
            layer_name = name_list[p]
            if self.is_transfer == True and 'fc' not in layer_name:
                continue
            grad /= (torch.norm(grad).detach() + 1e-8)
            p = p / (torch.norm(p, p=dim).detach() + 1e-8)
            # Three successive signed square roots of the magnitude (1e-8 is
            # added to the magnitude before each root).
            grad = grad.sign() * torch.pow(torch.abs(grad) + 1e-8, 0.5)
            grad = grad.sign() * torch.pow(torch.abs(grad) + 1e-8, 0.5)
            grad = grad.sign() * torch.pow(torch.abs(grad) + 1e-8, 0.5)
            grad /= (torch.norm(grad, p=dim).detach() + 1e-8)

            layer_loss = torch.abs((grad + self.grad_scale * p.detach())).pow(dim).sum()

            stationarity_loss += layer_loss

        if math.isnan(stationarity_loss):
            # Raise instead of calling exit(): exit() with no argument ends
            # the process with status 0, which reports the failed run as a
            # success.
            raise RuntimeError("stationarity loss is nan")

        if self.architecture == 'resnet':
            stationarity_loss /= 10

        return stationarity_loss * 100000

    def get_total_variation_loss(self, img):
        """Mean absolute difference between neighbouring rows and columns of ``img``.

        ``img`` must be 4-D; the sum of the two L1 terms is divided by
        ``height * width`` (the last two dimensions).
        """
        _, _, height, width = img.size()

        tv_h = F.l1_loss(img[:, :, 1:, :], img[:, :, :-1, :])
        tv_w = F.l1_loss(img[:, :, :, 1:], img[:, :, :, :-1])

        # Normalize by size for consistency across different input sizes
        return (tv_h + tv_w) / (height * width)

    def get_a_norm_loss(self, img):
        """Sum over the first dimension of the L2 norm of each flattened 4-D sample."""
        batch_size, _, _, _ = img.size()
        flat = img.view(batch_size, -1)
        return torch.sum(torch.norm(flat, dim = 1, p = 2))

    def get_verification_loss(self, samples):
        """ELU-based penalty on ``samples``, scaled by ``NUM_SAMPLES``.

        Sums the means of ``elu(-10 * (1 - samples)) + 1`` and
        ``elu(-100 * samples) + 1``.
        """
        penalty_above_1_corrected = F.elu(-10 * (1 - samples)) + 1 
        penalty_below_0_corrected = F.elu(-100 * samples) + 1
        loss_verify = penalty_above_1_corrected.mean() + penalty_below_0_corrected.mean()

        return loss_verify * self.NUM_SAMPLES

    def data_from_index(self, index, is_augment=False, all = False):
        """Gather samples, lambdas and targets for ``index``.

        Args:
            index: Index used on the sample tensor, ``trainable_lambda`` and
                ``target``.
            is_augment: Apply ``self.Aug.DiffAugment`` to the gathered samples.
            all: Unused; kept for interface compatibility.

        Returns:
            ``(samples, samples_lambda, target)``. Samples are upscaled by
            ``ZIP_RATE`` when it is greater than 1.

        Raises:
            ValueError: If ``ZIP_RATE`` is not 1, 2 or 4.
        """
        if self.ZIP_RATE == 1:
            source = self.trainable_params_1
        elif self.ZIP_RATE == 2:
            source = self.trainable_params_2
        elif self.ZIP_RATE == 4:
            source = self.trainable_params_4
        else:
            raise ValueError(f"Unsupported ZIP_RATE {self.ZIP_RATE!r}; expected 1, 2 or 4")

        target = self.target[index]
        samples = source[index]
        samples_lambda = self.trainable_lambda[index]
        if self.ZIP_RATE > 1:
            samples = upscale_tensor(samples, self.ZIP_RATE)
        if is_augment:
            samples = self.Aug.DiffAugment(samples)
        return samples, samples_lambda, target 

    def get_lambda_nll(self, samples, target, samples_lambda, is_augment = False, for_primal = False):
        """Return the mean loss of the reference network on ``samples``.

        With ``for_primal=True`` the loss comes from ``only_right_loss`` using
        detached lambdas, and the running accuracy for the (augmented or
        plain) branch is updated and logged. Otherwise the per-sample NLL is
        weighted by ``relu(samples_lambda)`` (CelebA uses an unweighted
        binary cross-entropy against a one-hot target), ``self.loss`` is set
        to the mean loss divided by its own value, and the mean is logged to
        wandb as ``calculate_loss``.
        """
        if for_primal:
            output = self.frozen_model_forward(samples, is_augment, one_temp=True)
            loss = only_right_loss(output, target, samples_lambda.detach(), is_celeba=self.is_celeba)
            batch_acc = self.accuracy(torch.argmax(output, dim=-1), target)
            if is_augment:
                self.aug_primal_acc = 0.99 * self.aug_primal_acc + 0.01 * batch_acc
                self.log("aug_primal_acc", self.aug_primal_acc)
            else:
                self.primal_acc = 0.99 * self.primal_acc + 0.01 * batch_acc
                self.log("primal_acc", self.primal_acc)
                # my_acc is refreshed only on the non-augmented call.
                self.my_acc = (self.primal_acc + self.aug_primal_acc) /2
                self.log("my_acc", self.my_acc)
        else:
            output = self.frozen_model_forward(samples, is_augment)
            if self.is_celeba:
                mytarget = F.one_hot(target, num_classes = self.num_classes)
                loss = F.binary_cross_entropy_with_logits(output, mytarget.float(), reduction="none")
            else:
                loss = F.nll_loss(output, target, reduction="none") 
                # Negative lambdas contribute nothing.
                loss = loss * F.relu(samples_lambda).flatten()
        loss = torch.mean(loss)

        if not for_primal:
            self.loss = loss / loss.item()
            wandb.log({"calculate_loss": loss})

        return loss

    def get_nll_both(self, samples, target):
        """Return ``(loss, aug_loss)`` from ``only_right_loss`` on plain and augmented inputs.

        Also updates and logs the running accuracies.
        NOTE(review): the augmented average uses 0.9 / 0.1 here while
        get_lambda_nll uses 0.99 / 0.01 for the same attribute — confirm
        which smoothing factor is intended.
        """
        output = self.frozen_model_forward(samples)
        loss = only_right_loss(output, target, is_celeba=self.is_celeba) 
        aug_output = self.frozen_model_forward(samples, is_augment=True)
        aug_loss = only_right_loss(aug_output, target, is_celeba=self.is_celeba)
        self.primal_acc = 0.99 * self.primal_acc + 0.01 * self.accuracy(torch.argmax(output, dim=-1), target)
        self.aug_primal_acc = 0.9 * self.aug_primal_acc + 0.1 * self.accuracy(torch.argmax(aug_output, dim=-1), target)
        self.my_acc = (self.primal_acc + self.aug_primal_acc) /2
        self.log("primal_acc", self.primal_acc)
        self.log("aug_primal_acc", self.aug_primal_acc)
        self.log("my_acc", self.my_acc)

        return loss, aug_loss

    def compute_loss(self):
        """Assemble the training loss for one random mini-batch of samples.

        The total is the sum of the primal, augmented primal, stationarity,
        augmented stationarity, total-variation and norm terms, each counted
        once. The share of every term is logged to wandb as a percentage.

        Returns:
            The total multiplied by ``batch_size`` and ``loss_weight``.
        """
        perm = torch.randperm(len(self.trainable_params_1))[:self.batch_size]
        index = self.index_pool[perm]

        stationarity_weight = self.stationarity_weight
        samples, samples_lambda, target = self.data_from_index(index)
        wandb.log({"samples_max": samples.max(), "samples_min": samples.min(), "samples_mean": samples.mean(), "samples_std": samples.std()})

        # Fraction of lambdas in the batch that are negative.
        zero_percentage = torch.sum(samples_lambda < 0) / len(samples_lambda)
        wandb.log({"lambda_percentage": zero_percentage * 100})

        stationarity = self.get_stationarity(samples, samples_lambda, target, is_augment=False)
        aug_stationarity = self.get_stationarity(samples, samples_lambda, target, is_augment=True)
        primal = self.get_lambda_nll(samples, target, samples_lambda, is_augment=False, for_primal = True)
        aug_primal = self.get_lambda_nll(samples, target, samples_lambda, is_augment=True, for_primal = True)

        primal_loss = self.stationarity_rate * primal * self.primal_rate
        # NOTE(review): self.aug_primal_rate is stored but not used; the
        # augmented term is scaled by primal_rate — confirm.
        aug_primal_loss = self.aug_stationarity_rate * aug_primal * self.primal_rate
        stationarity_loss = self.stationarity_rate * stationarity * stationarity_weight
        aug_stationarity_loss = self.aug_stationarity_rate * aug_stationarity  * stationarity_weight
        tot_loss = self.get_total_variation_loss(samples) * self.tv_scale * 1000
        alpha_loss = self.get_a_norm_loss(samples) * self.alpha_scale
        # The total-variation term is counted once. It used to be added twice,
        # so double tv_scale to reproduce the weighting of earlier runs.
        loss = primal_loss + aug_primal_loss + stationarity_loss + aug_stationarity_loss + tot_loss + alpha_loss

        self.log("total_loss", loss)

        wandb.log({
            "aug_stationarity_loss":  aug_stationarity_loss / loss * 100,
            "stationarity_loss": stationarity_loss / loss * 100,
            "aug_primal_loss": aug_primal_loss / loss * 100,
            "primal_loss": primal_loss / loss * 100,
            "variation_loss" : tot_loss / loss * 100,
            "alpha_loss" : alpha_loss / loss * 100
        })

        return loss * self.batch_size * self.loss_weight

    @staticmethod
    def _min_max_normalize_(tensor):
        """Rescale each entry along dim 0 of a 4-D tensor to [0, 1], in place."""
        flat = tensor.view(tensor.size(0), -1)
        mins = flat.min(dim=1)[0].view(-1, 1, 1, 1)
        maxs = flat.max(dim=1)[0].view(-1, 1, 1, 1)
        tensor.sub_(mins)
        # Guard the range so a constant sample becomes all zeros rather than NaN.
        tensor.div_((maxs - mins).clamp_min(1e-12))

    def add_noise_and_mult_every(self, multiply_alpha = 0.9):
        """Min-max normalise the active sample tensor and blend it with noise, in place.

        For ``ZIP_RATE`` 2 the tensor is ``trainable_params_2``, for 1 it is
        ``trainable_params_1``; other values do nothing. Each sample is
        rescaled to [0, 1], multiplied by ``sqrt(multiply_alpha)`` and
        combined with min-max-normalised Gaussian noise multiplied by
        ``sqrt(1 - multiply_alpha)``.
        """
        if self.ZIP_RATE == 2:
            params = self.trainable_params_2
        elif self.ZIP_RATE == 1:
            params = self.trainable_params_1
        else:
            return

        with torch.no_grad():
            data = params.data
            self._min_max_normalize_(data)

            noise = torch.randn_like(data)
            self._min_max_normalize_(noise)

            data.mul_(math.sqrt(multiply_alpha))
            data.add_(noise * math.sqrt(1-multiply_alpha))

    def training_step(self, batch, batch_idx):
        """Run one optimisation step; ``batch`` itself is not read.

        Every 10th batch the negative entropy of the reference outputs over
        all samples is logged; on batch 0 up to 50 random samples are logged
        to wandb as captioned images.

        Returns:
            ``compute_loss()`` multiplied by 0.01.
        """
        if batch_idx % 10 == 0:
            # Logging only, so no autograd graph is needed.
            with torch.no_grad():
                samples = upscale_tensor(self.trainable_params_1, self.ZIP_RATE)
                output_logit = self.frozen_model_forward(samples)
                negative_entropy = torch.mean(torch.sum(torch.exp(output_logit) * output_logit, dim=1))
            wandb.log({"negative_entropy": negative_entropy})

        loss = self.compute_loss()

        if batch_idx == 0:
            trainable_params = self.trainable_params_1

            with torch.no_grad():
                # exp of the log-softmax output, i.e. class probabilities.
                output_logit = torch.exp(self.frozen_model_forward(trainable_params))
            output_max = torch.argmax(output_logit, dim = -1)
            output_prob = torch.max(output_logit, dim = -1).values

            show_size = 50
            idx = torch.randperm(len(trainable_params))[:show_size]

            samples = trainable_params[idx]
            samples_answer = self.target[idx]
            samples_est = output_max[idx]
            samples_prob = output_prob[idx]
            wandb.log({f"samples": [wandb.Image(sample, caption=f"Class {self.CLASS_NAME[samples_answer[i]]}, Estimated {self.CLASS_NAME[samples_est[i]]}, Prob {round(samples_prob[i].item(),2)}") for i, sample in enumerate(samples)]}, commit=False)

        return loss * 0.01

    def configure_optimizers(self):
        """SGD with separate parameter groups for the samples and the lambdas."""
        optimizer = torch.optim.SGD([
            {'params': self.trainable_params_1, 'lr': self.lr_image * 10, 'momentum': self.momentum_image, 'weight_decay': self.weight_decay_x},
            {'params': self.trainable_lambda, 'lr': self.lr_lambda, 'momentum': self.momentum_lambda, 'weight_decay': self.weight_decay_l},
        ], lr=1e-2)

        return optimizer
    
class LitResnet_SVM(LightningModule):
    """LightningModule that trains a fresh model on a ``Support_Vectors`` dataset.

    Uses manual optimisation (``automatic_optimization = False``). A
    ``LitResnet`` loaded from ``config['REF_MODEL_PATH']`` is kept as
    ``frozen_net`` and exposed through ``target``.

    NOTE(review): ``val_dataloader`` is defined but the class has no
    ``validation_step`` — confirm whether validation is meant to run.
    """

    def __init__(self, args, config):
        super().__init__()
        self.save_hyperparameters(args, config)
        self.Aug = DiffAug(config, "crop, translation, flip", args.crop_size, args.angle, args.noise_ratio, args.translation_ratio, args.cutout_ratio)
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = config['NUM_CLASSES'])
        self.lr = args.lr_retrain
        self.automatic_optimization = False
        self.DATAPATH = config['DATAPATH']
        self.momentum = args.momentum_retrain
        self.aug_ratio = args.aug_ratio
        self.acc_threshold = 0.9
        self.DATASET = config['DATASET']
        self.batch_size = min(args.batch_size * config['NUM_CLASSES'], config['NUM_SAMPLES'] * config['NUM_CLASSES'])
        self.samaug = True
        self.temperature1 = args.temperature1
        # train_dataloader() passes this flag to Support_Vectors; it was never
        # initialised before, so the first call raised AttributeError.
        # NOTE(review): the default of False is an assumption — confirm the
        # intended value and where it should come from.
        self.is_intersec = getattr(args, 'is_intersec', False)

        model_path = config['REF_MODEL_PATH']
        self.frozen_net = LitResnet(config)
        new_state_dict = OrderedDict()
        # map_location="cpu" lets a checkpoint written on a GPU host be opened
        # anywhere; load_state_dict() copies the values into the model.
        for k, v in torch.load(model_path, map_location="cpu")['state_dict'].items():
            # Drop the first six characters of each key — presumably a
            # "model." prefix; TODO confirm against the checkpoint layout.
            new_state_dict[k[6:]] = v

        self.frozen_net.model.load_state_dict(new_state_dict)
        self.frozen_net = self.frozen_net.cuda()
        self.REF_MODEL_PATH = model_path
        self.SUPPORT_PATH = config['SUPPORT_PATH']
        self.config = config
        self.model = create_model(config['ARCHITECTURE'], config)

    def forward(self, x):
        """Run the model being trained."""
        return self.model(x)

    def target(self, x):
        """Run the reference model (``frozen_net.model``)."""
        return self.frozen_net.model(x)

    def train_dataloader(self):
        """Shuffled loader over ``Support_Vectors`` with batch size ``min(len(dataset), 64)``."""
        dataset = Support_Vectors(self.REF_MODEL_PATH, self.SUPPORT_PATH, self.config, self.is_intersec)
        return DataLoader(dataset, batch_size=min(len(dataset), 64), shuffle=True)

    def _build_eval_dataloader(self):
        """Build the held-out loader shared by ``test_dataloader`` and ``val_dataloader``.

        STL10 uses ``split='test'``; CIFAR100, CIFAR10 and the MNIST fallback
        use ``train=False``. MNIST is read from ``./data`` rather than
        ``DATAPATH``.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        if self.DATASET == 'STL10':
            eval_data = datasets.STL10(self.DATAPATH, split='test', download=True, transform=transform)
        elif self.DATASET == 'CIFAR100':
            eval_data = datasets.CIFAR100(self.DATAPATH, train=False, download=True, transform=transform)
        elif self.DATASET == 'CIFAR10':
            eval_data = datasets.CIFAR10(self.DATAPATH, train=False, download=True, transform=transform)
        else:
            eval_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        return DataLoader(eval_data, batch_size=1024, shuffle=False)

    def test_dataloader(self):
        # The STL10 branch used to drop the dataset it built and then fail
        # with UnboundLocalError; the shared builder assigns it in every branch.
        return self._build_eval_dataloader()

    def val_dataloader(self):
        return self._build_eval_dataloader()

    def compute_loss(self, batch):
        """Blend cross-entropy on plain and augmented inputs.

        Args:
            batch: 4-tuple; the first element is the input and the second the
                label. The remaining two elements are not used.

        Returns:
            ``(1 - aug_ratio) * CE(plain) + aug_ratio * CE(augmented)``, with
            logits divided by ``temperature1``. Also logs ``train_acc`` for
            the plain inputs.
        """
        x, y_t, _, _ = batch
        x = x.detach()
        ax = self.Aug.DiffAugment(x).detach()
        logits = self(x)
        a_logits = self(ax)
        acc = self.accuracy(torch.argmax(logits, dim=-1), y_t)
        self.log("train_acc", acc, prog_bar=True)
        hard_loss = F.cross_entropy(logits/self.temperature1, y_t)
        a_hard_loss = F.cross_entropy(a_logits/self.temperature1, y_t)
        return (1-self.aug_ratio) * (hard_loss) + self.aug_ratio * (a_hard_loss)

    def training_step(self, batch, batch_idx):
        """One manual-optimisation step: zero grads, backward, optimiser step."""
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)
        # manual_backward is Lightning's entry point for the backward pass
        # under manual optimisation.
        self.manual_backward(loss)
        opt.step()

        return loss

    def evaluate(self, batch, stage=None):
        """Compute NLL and accuracy; log and return accuracy when ``stage`` is given."""
        x, y = batch
        logits = F.log_softmax(self(x), dim=-1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=-1)
        acc = self.accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", acc, prog_bar=True)
            return acc

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Adam over the trained model's parameters at ``lr_retrain``."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
